#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import re
import os


def analyze_sql_file(file_path):
    """
    Analyze a SQL DDL file and build ALTER statements for decimal(12,2) columns.

    The file is read as UTF-8, falling back to GBK if it is not valid UTF-8.
    Its content is split on ';' and every fragment that contains a
    CREATE TABLE is handed to process_create_table_statement.

    Args:
        file_path: Path of the SQL file to read.

    Returns:
        List of ALTER TABLE statement strings, in file order.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is neither valid UTF-8 nor valid GBK.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except UnicodeDecodeError:
        # Not valid UTF-8: retry with GBK.
        with open(file_path, 'r', encoding='gbk') as f:
            content = f.read()

    # Searched (not anchored), so comment lines preceding the CREATE TABLE
    # keyword inside the same fragment are tolerated.
    create_table_re = re.compile(r'CREATE\s+TABLE', re.IGNORECASE)

    alter_statements = []
    # NOTE(review): naive split — a ';' inside a string literal (for example a
    # column COMMENT) would cut a statement in two.
    for sql_statement in content.split(';'):
        sql_statement = sql_statement.strip()

        # Skip empty fragments and statements that are not CREATE TABLE
        # (DROP, USE, SET, trailing comments, ...). Previously these were all
        # passed on and each printed a bogus "cannot parse table name" warning.
        if not sql_statement or not create_table_re.search(sql_statement):
            continue

        alter_statements.extend(process_create_table_statement(sql_statement))

    return alter_statements


def process_create_table_statement(create_table_sql):
    """
    Process a single CREATE TABLE statement.

    Extracts the table name and every column declared as decimal(12,2), and
    builds one ALTER statement per such column that changes it to
    decimal(38,6) NULL.

    Args:
        create_table_sql: Text of one SQL statement (without the trailing ';').

    Returns:
        A list of ALTER TABLE statements, possibly empty. When no
        CREATE TABLE header can be parsed, a warning is printed and an empty
        list is returned.
    """
    alter_statements = []

    # Table name: the backtick-quoted identifier after CREATE TABLE, with an
    # optional IF NOT EXISTS in between.
    table_name_pattern = r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`([^`]+)`\s*\('
    table_name_match = re.search(table_name_pattern, create_table_sql, re.IGNORECASE)

    if not table_name_match:
        print(f"无法解析表名: {create_table_sql[:100]}...")
        return alter_statements

    table_name = table_name_match.group(1)

    # Columns typed decimal(12,2). Whitespace inside the parentheses is
    # optional, so both "decimal(12, 2)" and "decimal(12,2)" match.
    decimal_field_pattern = r'`([^`]+)`\s+decimal\(\s*12\s*,\s*2\s*\)'
    field_matches = re.findall(decimal_field_pattern, create_table_sql, re.IGNORECASE)

    for field_name in field_matches:
        # Identifiers are backtick-quoted so reserved words (e.g. `order`) or
        # names with special characters still yield valid SQL.
        # NOTE(review): the new definition is always "NULL" and does not carry
        # over the original column's NOT NULL / DEFAULT / COMMENT — confirm
        # this is intended.
        alter_statement = (
            f"ALTER TABLE `{table_name}` modify COLUMN `{field_name}` decimal(38,6) NULL;"
        )
        alter_statements.append(alter_statement)

        print(f"发现需要修改的字段: 表={table_name}, 字段={field_name}")

    return alter_statements


def main(
    file_path="C:\\Users\\56209\\PycharmProjects\\pythonProject\\sqlscript\\temp\\all_tables_ddl_starrocks-第二批.sql",
    output_file="temp/alter_statements-第二批.sql",
):
    """
    Generate ALTER statements for one DDL file, print them and save them.

    Args:
        file_path: SQL DDL file to analyze (defaults to the original
            hard-coded path).
        output_file: File the generated statements are written to (UTF-8).
            Missing parent directories are created.
    """
    from datetime import datetime

    if not os.path.exists(file_path):
        print(f"文件不存在: {file_path}")
        return

    try:
        alter_statements = analyze_sql_file(file_path)

        print("\n=== 生成的ALTER语句 ===")
        for statement in alter_statements:
            print(statement)

        print(f"\n总共找到 {len(alter_statements)} 个需要修改的字段")

        # Create the output directory first. Otherwise open() fails whenever
        # the script runs from a working directory without a "temp" folder.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("-- 自动生成的ALTER语句\n")
            f.write("-- 将decimal(12,2)字段修改为decimal(38,6)\n")
            f.write(f"-- 生成时间: {datetime.now()}\n\n")
            f.writelines(statement + "\n" for statement in alter_statements)

        print(f"\nALTER语句已保存到文件: {output_file}")

    except (OSError, UnicodeError) as e:
        # Expected I/O and decoding failures are reported. Anything else is a
        # programming error and propagates with its full traceback.
        print(f"处理文件时发生错误: {e}")


if __name__ == "__main__":
    # Script entry point: only runs when executed directly, not on import.
    main()
